# -*- coding:utf-8 -*-

import importlib
import tensorflow as tf
from dto.dataset_op_dto import DatasetOpDto
from dto.devide_dataset_dto import DividedDataSetDto
from config.glob.global_pool import global_pool
from dto.dataset_handle_dto import DatasetHandleDto
from utils.log_utils import log_debug


"""
使用数据集构建成tf.data.Dataset格式
为train、validation、test每个构建一个Dataset
"""


class Parser:
    """Wrap a divided dataset into ``tf.data.Dataset`` pipelines.

    One pipeline is built per split (train / validation / test) that is
    actually present in the supplied :class:`DividedDataSetDto`, all of
    them living on a dedicated graph/session pair owned by this object.
    """

    def __init__(self, handle_dto, divided_dataset: DividedDataSetDto):
        # Per-split custom parse functions from the handle DTO; fall back
        # to None when no DTO was supplied.
        train_func = handle_dto.train_func if handle_dto else None
        validation_func = handle_dto.validation_func if handle_dto else None
        test_func = handle_dto.test_func if handle_dto else None

        # Dedicated graph + session so dataset ops do not pollute the
        # default graph.
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        # One op-DTO per split, all sharing the same session.
        self.train = DatasetOpDto(self.sess, train_func)
        self.validation = DatasetOpDto(self.sess, validation_func)
        self.test = DatasetOpDto(self.sess, test_func)
        self.build_dataset(divided_dataset)

    def build_train(self, num, features, labels, batch_size):
        """Build the train-split dataset and its iterator.

        :param num: number of samples in the split
        :param features: feature data for the split
        :param labels: label data for the split
        :param batch_size: mini-batch size used during training
        :return: None
        """
        split = self.train
        split.batch_size = batch_size
        split.num = num
        split.build_dataset(features, labels, is_train=True)
        split.build_iterator()

    def build_validation(self, num, features, labels):
        """Build the validation-split dataset and its iterator.

        The whole split is served as one batch (batch_size == num).

        :param num: number of samples in the split
        :param features: feature data for the split
        :param labels: label data for the split
        :return: None
        """
        split = self.validation
        split.batch_size = num
        split.num = num
        split.build_dataset(features, labels, is_train=False)
        split.build_iterator()

    def build_test(self, num, features, labels):
        """Build the test-split dataset and its iterator.

        The whole split is served as one batch (batch_size == num).

        :param num: number of samples in the split
        :param features: feature data for the split
        :param labels: label data for the split
        :return: None
        """
        split = self.test
        split.batch_size = num
        split.num = num
        split.build_dataset(features, labels, is_train=False)
        split.build_iterator()

    def build_dataset(self, ori_data):
        """Build every split present in *ori_data* inside ``self.graph``.

        :param ori_data: DividedDataSetDto holding the split data
        :return: None
        """
        with self.graph.as_default():
            if ori_data.train_num:
                self.build_train(
                    ori_data.train_num,
                    ori_data.train_features,
                    ori_data.train_labels,
                    global_pool.config.batch_size,
                )
            if ori_data.validation_num:
                self.build_validation(
                    ori_data.validation_num,
                    ori_data.validation_features,
                    ori_data.validation_labels,
                )
            if ori_data.test_num:
                self.build_test(
                    ori_data.test_num,
                    ori_data.test_features,
                    ori_data.test_labels,
                )


def parse(data):
    """Build the dataset pipelines and install them into the global pool.

    Optionally loads a user-configured handle module (per
    ``ds_parser.use_handle``) whose ``handle`` function is used to
    customize dataset parsing, then stores the resulting :class:`Parser`
    on ``global_pool.dataset_op``.

    :param data: DividedDataSetDto holding the split data
    :return: None
    """
    handle_dto = None
    if global_pool.config.ds_parser.use_handle:
        # Dynamically import the parser module named in the config.
        module_path = global_pool.config.ds_parser.handle
        handle_module = importlib.import_module(module_path)
        # NOTE(review): the same handle() is wired to train, validation
        # and test — confirm this is intentional rather than per-split
        # functions.
        handle_dto = DatasetHandleDto(handle_module.handle, handle_module.handle, handle_module.handle)
    else:
        log_debug('\033[32mno ds_parser.handle config, do not load ds_parser.handle func\033[0m')
    # Build the parser; its __init__ constructs all available splits.
    global_pool.dataset_op = Parser(handle_dto, data)
